import torch
import numpy as np
from data.load_llff import load_llff_data, poses2pytorch3dcam

class LlffDataset(torch.utils.data.Dataset):
    """Map-style dataset over one LLFF scene loaded via ``load_llff_data``.

    Every ``llffhold``-th image (starting with the first) is held out for
    evaluation; the remaining images are used for training.

    * ``split == 'train'``: items cycle over the training views, and the
      dataset reports ``epoch_len`` items per epoch.
    * any other ``split``: items are the held-out views followed by the
      cameras built from the extra poses returned by ``load_llff_data``
      (stored as ``self.test_poses``). Those extra cameras have no image,
      so their ``'target_image'`` is ``None``.
    """

    def __init__(self, split: str = 'train', epoch_len: int = 1000,
                 llffhold: int = 8, **llff_args):
        """Load the scene and build the images/cameras for ``split``.

        Args:
            split: ``'train'`` selects the training views; any other value
                selects the evaluation views.
            epoch_len: Number of items reported by ``len()`` for the train
                split. Ignored otherwise, where the length is the number of
                evaluation cameras.
            llffhold: Stride of the held-out images (every ``llffhold``-th
                image, starting at index 0, is reserved for evaluation).
            **llff_args: Forwarded unchanged to ``load_llff_data``.
        """
        super().__init__()
        images, poses, _, self.test_poses, _ = load_llff_data(**llff_args)
        i_test = np.arange(images.shape[0])[::llffhold]
        self.split = split
        if split == 'train':
            # Set membership is O(1); testing `i in ndarray` scans the whole
            # array for every candidate index.
            held_out = set(i_test.tolist())
            train_ids = [i for i in range(len(images)) if i not in held_out]
            self.target_images = torch.from_numpy(images[train_ids]).float()
            self.target_cameras = poses2pytorch3dcam(self.target_images, poses[train_ids])
            self.epoch_len = epoch_len
        else:
            self.target_images = torch.from_numpy(images[i_test]).float()
            # Held-out cameras first, then the cameras for the extra poses.
            # NOTE(review): the extra poses are converted using the held-out
            # images as the first argument; presumably the helper only needs
            # them for image-size information -- confirm in load_llff.
            self.target_cameras = (
                poses2pytorch3dcam(self.target_images, poses[i_test])
                + poses2pytorch3dcam(self.target_images, self.test_poses)
            )
            self.epoch_len = len(self.target_cameras)

    def __getitem__(self, i):
        """Return the camera (and image, when one exists) for item ``i``.

        Indices wrap around the available cameras, so a train epoch may be
        longer than the number of training views. Negative indices count
        from the end of the epoch, as for built-in sequences.

        Returns:
            A dict with keys ``'target_camera'`` and ``'target_image'``.
            ``'target_image'`` is ``None`` for evaluation items that lie past
            the held-out images.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))`` or the
                dataset holds no cameras. Raising here is what lets plain
                iteration (``for item in dataset``) terminate.
        """
        length = len(self)
        num_cameras = len(self.target_cameras)
        if num_cameras == 0 or not -length <= i < length:
            raise IndexError(
                f'index {i} is out of range for dataset of length {length}'
            )
        idx = i + length if i < 0 else i
        slot = idx % num_cameras
        has_image = self.split == 'train' or idx < len(self.target_images)
        return {
            'target_camera': self.target_cameras[slot],
            'target_image': self.target_images[slot] if has_image else None,
        }

    def __len__(self):
        """Return the number of items in one epoch."""
        return self.epoch_len